{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DefuseOps\n",
    "\n",
    "源码文件 `tvm/src/relay/transforms/defuse_ops.cc` 实现了 Relay 框架中的 `DefuseOps` 变换（pass），用于将算子融合（`relay::transform::FuseOps`）的结果还原为融合之前的状态，即 `x == DefuseOps(FuseOps(x))`。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该文件定义了名为 `DefuseOpsMutator` 的类，它继承自 `ExprMutator`，并包含一个同样继承自 `ExprMutator` 的嵌套类 `FuncBodyMutator`。\n",
    "\n",
    "与 `DefuseOpsMutator` 相关的函数有两个：类中重写的成员函数 `VisitExpr_(const CallNode* n)`，以及定义在类之外的入口函数 `DefuseOps(const Expr& expr)`。\n",
    "\n",
    "`VisitExpr_(const CallNode* n)` 函数接受指向 `CallNode` 对象的指针作为参数，并返回 `Expr` 对象。在函数内部，首先调用父类 `ExprMutator` 的 `VisitExpr_` 函数来处理 `n`（从而先递归处理其算子与参数）。然后，如果返回的对象是 `CallNode` 类型，就进一步检查其算子是否为 `FunctionNode` 类型。如果是，则创建无序的哈希表 `name_to_args`，遍历函数的形参列表，将每个形参的名称与调用处对应的实参表达式存入该哈希表。接着，以 `name_to_args` 为参数构造 `FuncBodyMutator` 对象，并对函数体调用其 `Mutate` 函数：`FuncBodyMutator` 会把函数体中出现的形参变量按名称替换为对应的实参，从而将被调用的函数内联展开，最终返回处理后的表达式。如果算子不是 `FunctionNode` 类型，则原样返回父类处理后的结果。\n",
    "\n",
    "`DefuseOps(const Expr& expr)` 函数接受 `Expr` 对象的引用作为参数，并返回 `Expr` 对象。它的作用是创建 `DefuseOpsMutator` 对象，并调用其 `Mutate` 函数来处理输入的表达式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {\n",
      "  %5 = fn (%FunctionVar_2_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_2_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_2_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern=\"nn.conv2d_add_nn.relu_\", Composite=\"ccompiler.conv_add_relu\") -> Tensor[(1, 64, 112, 112), float32] {\n",
      "    %3 = nn.conv2d(%FunctionVar_2_0, %FunctionVar_2_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "    %4 = add(%3, %FunctionVar_2_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "    nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */\n",
      "  } /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %6 = %5(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %7 = nn.max_pool2d(%6, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %8 = fn (%FunctionVar_1_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_1_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_1_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern=\"nn.conv2d_add_nn.relu_\", Composite=\"ccompiler.conv_add_relu\") -> Tensor[(1, 64, 56, 56), float32] {\n",
      "    %1 = nn.conv2d(%FunctionVar_1_0, %FunctionVar_1_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "    %2 = add(%1, %FunctionVar_1_2) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "    nn.relu(%2) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
      "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %9 = %8(%7, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %10 = fn (%FunctionVar_0_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_0_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_0_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern=\"nn.conv2d_add_\", Composite=\"ccompiler.conv_add_relu\") -> Tensor[(1, 64, 56, 56), float32] {\n",
      "    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "    add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
      "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %11 = %10(%9, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  add(%11, %7) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
      "} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 64, 56, 56), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tvm import relay\n",
    "import tvm\n",
    "from tvm_book.tvm_utils.llvm_utils import run_llvm_graph\n",
    "from tvm_book.tvm_utils.split_graph import graph_split\n",
    "from tvm.relay.dataflow_pattern import is_op, wildcard\n",
    "\n",
    "def make_conv_add_relu_pattern():\n",
    "    \"\"\"创建如下模式\n",
    "\n",
    "     conv2d\n",
    "        |\n",
    "      (add)\n",
    "        |\n",
    "      (relu)\n",
    "    \"\"\"\n",
    "    x = wildcard()\n",
    "    w = wildcard()\n",
    "    bias = wildcard()\n",
    "    r = is_op(\"nn.conv2d\")(x, w)\n",
    "    r = is_op(\"add\")(r, bias) | r\n",
    "    # 激活函数\n",
    "    r = r.optional(lambda x: is_op(\"nn.relu\")(x))\n",
    "    return r\n",
    "\n",
    "def load_model(input_shape=[1, 3, 224, 224]):\n",
    "    \"\"\"加载前端模型\"\"\"\n",
    "    import torch\n",
    "    from torchvision.models import resnet18\n",
    "    from torchvision.models.resnet import ResNet18_Weights\n",
    "    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)\n",
    "    data = torch.randn(*input_shape)\n",
    "    return torch.jit.trace(model.eval(), data)\n",
    "\n",
    "size = 224, 224\n",
    "input_shape = (1, 3, *size)\n",
    "input_name = \"data\"\n",
    "traced_model = load_model(input_shape).eval()\n",
    "# 将前端模型翻译为 relay 模型\n",
    "origin_mod, origin_params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])\n",
    "# 获取子图\n",
    "split_conf = [{\"op_name\": \"add\", \"op_index\": 0}]\n",
    "mod = graph_split(origin_mod[\"main\"], split_conf)[0]\n",
    "compiler_name = \"ccompiler\"\n",
    "pattern_table = [\n",
    "    (f\"{compiler_name}.conv_add_relu\", make_conv_add_relu_pattern()),\n",
    "]\n",
    "merge_passes = tvm.transform.Sequential([\n",
    "    relay.transform.InferType(),\n",
    "    relay.transform.MergeComposite(pattern_table),\n",
    "    # # relay.transform.AnnotateTarget([compiler_name]),\n",
    "    relay.transform.PartitionGraph(),\n",
    "])\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    with relay.quantize.qconfig(\n",
    "        calibrate_mode=\"kl_divergence\",\n",
    "        weight_scale=\"max\",\n",
    "        skip_conv_layers=[],\n",
    "        skip_dense_layer=False\n",
    "    ):\n",
    "        # 量化前准备\n",
    "        run_mod = relay.quantize.prerequisite_optimize(mod, origin_params)\n",
    "        run_mod = merge_passes(run_mod) # 算子融合\n",
    "print(run_mod[\"main\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "直接调用 {class}`tvm.relay.transform.DefuseOps`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {\n",
       "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
       "  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
       "  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
       "  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  %4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  %7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  %8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
       "  add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
       "}\n"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "relay.transform.DefuseOps()(run_mod)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了更好地理解 {class}`tvm.relay.transform.DefuseOps` 所实现的功能，可以使用 Python 模拟其行为："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from tvm.relay import Call\n",
    "from tvm.relay.function import Function\n",
    "\n",
    "\n",
    "@tvm.relay.transform.function_pass(opt_level=1)\n",
    "class DefuseTransform:\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "        \n",
    "    def reset(self):\n",
    "        self.name_to_args_ = {}\n",
    "\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        obj = self\n",
    "\n",
    "        @dataclass\n",
    "        class FuncBodyMutator(tvm.relay.ExprMutator):\n",
    "            name_to_args_: dict\n",
    "            memo_map: dict = field(default_factory=dict)\n",
    "\n",
    "            def visit_var(self, var):\n",
    "                return self.name_to_args_[var.name_hint]\n",
    "\n",
    "        class Replace(tvm.relay.ExprMutator):\n",
    "            def visit_call(self, call):\n",
    "                new_fn = self.visit(call.op)\n",
    "                new_args = [self.visit(arg) for arg in call.args]\n",
    "                call = Call(new_fn, new_args, call.attrs, call.type_args, call.span)\n",
    "                if isinstance(call.op, Function):\n",
    "                    name_to_args = {}\n",
    "                    for param, arg in zip(new_fn.params, new_args):\n",
    "                        name_to_args[param.name_hint] = arg\n",
    "                    call = FuncBodyMutator(name_to_args).visit(new_fn.body)\n",
    "                return call\n",
    "        return Replace().visit(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "transform = DefuseTransform()\n",
    "_mod = transform(run_mod)\n",
    "print(_mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0], strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %1 = nn.batch_norm(%0, meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4]) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;\n",
      "  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
      "  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %5 = nn.conv2d(%4, meta[relay.Constant][5], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %6 = nn.batch_norm(%5, meta[relay.Constant][6], meta[relay.Constant][7], meta[relay.Constant][8], meta[relay.Constant][9]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;\n",
      "  %7 = %6.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %8 = nn.relu(%7) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %9 = nn.conv2d(%8, meta[relay.Constant][10], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  %10 = nn.batch_norm(%9, meta[relay.Constant][11], meta[relay.Constant][12], meta[relay.Constant][13], meta[relay.Constant][14]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;\n",
      "  %11 = %10.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;\n",
      "  add(%11, %4) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tvm.relay.transform.suffixes import tag_suffixes\n",
    "\n",
    "print(tag_suffixes(mod))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
